import numpy as np
import pandas as pd
import os
import pdb
import time
import keras.backend as K
from tqdm import tqdm
from keras.layers import Input
from keras.preprocessing import sequence
from keras.layers import LSTM, Dense, Masking, Concatenate, concatenate, BatchNormalization, Bidirectional
from keras.utils import to_categorical
from sklearn.preprocessing import LabelEncoder
from keras.models import Sequential, Model
from keras import metrics
from sklearn.model_selection import train_test_split

#Load the training data.

# Expose only GPU "7" to CUDA, with devices numbered in PCI bus order.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]= "7"

start = time.time()
data_path = "./training_data/"

# Accumulators filled across all training files; one entry per record.
data_list_readdepths = []
data_list_indexes = []
data_list_canavar_preds = []
data_list_freec_preds = []


def _parse_record(line):
    """Split one stripped, comma-separated record into its four parts.

    Field 0 is ignored. Fields 1 and 2 are parsed as integers and fields 3
    and 4 are kept as strings, each after dropping its first character
    (presumably a separator such as a space or bracket -- confirm against the
    data files). Every remaining field is a read-depth token: the first '[',
    every ']' and every space are removed, and an empty token counts as 0.

    Returns:
        A tuple ``(index_pair, freec_pred, canavar_pred, read_depths)`` where
        ``index_pair`` is a tuple of two ints, both predictions are strings
        and ``read_depths`` is a list of ints.

    Raises:
        IndexError: if the record has fewer than five fields.
        ValueError: if an index or a non-empty read-depth token is not an
            integer literal.
    """
    # Split once and reuse the fields instead of re-splitting per column.
    fields = line.split(',')
    index_pair = (int(fields[1][1:]), int(fields[2][1:]))
    freec_pred = fields[3][1:]
    canavar_pred = fields[4][1:]

    read_depths = []
    for token in fields[5:]:
        token = token.replace('[', '', 1).replace(']', '').replace(' ', '')
        read_depths.append(int(token) if token else 0)
    return index_pair, freec_pred, canavar_pred, read_depths


# os.listdir() returns names in arbitrary order; sorting keeps the record
# order, and therefore the seeded train/test split, reproducible across runs.
files_list = sorted(os.listdir(data_path))

print("Loading training data to the memory...")
for filename in tqdm(files_list):
    file_path = os.path.join(data_path, filename)
    # Sub-directories cannot be opened as record files; skip them.
    if not os.path.isfile(file_path):
        continue

    with open(file_path) as f:
        for raw_line in f:
            record = raw_line.strip()
            # A blank (e.g. trailing) line carries no fields to parse.
            if not record:
                continue
            index_pair, freec_pred, canavar_pred, read_depths = _parse_record(record)
            data_list_indexes.append(index_pair)
            data_list_freec_preds.append(freec_pred)
            data_list_canavar_preds.append(canavar_pred)
            data_list_readdepths.append(read_depths)

end = time.time()
print("Loading of the data took ", end-start," seconds.")

# Drop any ']' characters left in the label strings by the field slicing.
data_list_canavar_preds = [x.replace(']','') for x in data_list_canavar_preds]


#convert data lists to numpy arrays
data_list_indexes = np.asarray(data_list_indexes)
# NOTE(review): both prediction arrays hold strings at this point, while the
# model treats them as numeric input/labels -- confirm that the training
# backend converts them, or cast explicitly once the value format is known.
data_list_canavar_preds = np.asarray(data_list_canavar_preds)
data_list_freec_preds = np.asarray(data_list_freec_preds)


# Pad (or truncate) every read-depth sequence to 192000 samples, filling with
# -1. The Python list is handed to pad_sequences directly: calling np.asarray
# on sequences of different lengths first raises ValueError on recent NumPy
# releases, and pad_sequences accepts a list of lists as-is.
data_list_readdepths = sequence.pad_sequences(data_list_readdepths, maxlen= 192000, value = -1)

# Downsample by averaging non-overlapping windows of 100 samples, giving
# 1920 values per sequence.
# NOTE(review): a window that mixes padding and real samples averages to a
# value other than -1, so the Masking(mask_value=-1) layer used by the model
# only skips windows that are entirely padding.
data_list_readdepths = data_list_readdepths.reshape(
    data_list_readdepths.shape[0], -1, 100
).mean(axis=2)





# Label encodings for CNVnator / XHMM predictions. They are only referenced by
# the commented-out statements below and are not applied to the FREEC /
# CANAVAR data this script trains on.
#   CNVNATOR PREDS: nan -> 0
#                   <DUP> -> 1
#                   <DEL> -> 2
#   XHMM PREDS: 'DEL' -> 0
#               'DUP' -> 1

# data_list_cnvnator_preds[data_list_cnvnator_preds == 'nan'] = 0
# data_list_cnvnator_preds[data_list_cnvnator_preds == "'<DUP>'"] = 1
# data_list_cnvnator_preds[data_list_cnvnator_preds == "'<DEL>'"] = 2

# data_list_xhmm_preds[data_list_xhmm_preds == "'DEL'"] = 0
# data_list_xhmm_preds[data_list_xhmm_preds == "'DUP'"] = 1

# data_list_xhmm_preds = to_categorical(data_list_xhmm_preds, num_classes =2)
# data_list_cnvnator_preds = to_categorical(data_list_cnvnator_preds, num_classes =3)

# Add a trailing feature axis so each sequence is (timesteps, 1), matching the
# shape of the model's read-depth input.
data_list_readdepths = np.expand_dims(data_list_readdepths, axis=2)
# Read depths are fed unscaled; a division by 45000 was left disabled here.


print("Read depths data matrix shape: ", data_list_readdepths.shape)
print("Freec predictions data matrix shape: ", data_list_freec_preds.shape)
print("Canavar predictions (labels) data matrix shape: ", data_list_canavar_preds.shape)

# How the arrays are wired into the model during training:
#   input1 <- data_list_freec_preds
#   input2 <- data_list_readdepths
#   labels <- data_list_canavar_preds

#model
max_length =  192000 # maximum length of read depth signals
# Number of timesteps after averaging windows of 100 samples. Floor division
# keeps this an int; true division yields 1920.0 on Python 3, which is not a
# valid shape dimension.
inpsize = max_length // 100

input1 = Input(shape=(1,)) # freec prediction
input2 = Input(shape=(inpsize,1)) # read depth sequence
# Timesteps whose value equals -1 (the padding value) are skipped downstream.
masked_input2 = Masking(mask_value = -1)(input2)
features1 = BatchNormalization()(masked_input2)
# Summarize the whole sequence with a bidirectional LSTM (128 units per direction).
features2 = Bidirectional(LSTM(128))(features1)
features3 = BatchNormalization()(features2)
# Append the scalar FREEC prediction to the sequence summary.
merged = concatenate([features3, input1])
features4 = Dense(100, activation='relu')(merged)
# Single ReLU output unit, so predictions are never negative.
output = Dense(1,activation='relu')(features4)

model = Model(inputs=[input1, input2], outputs = output)
# summary() prints the table itself and returns None, so it is not wrapped in print().
model.summary()
#comment

#train - test split
# Hold out 10% of the records as a test set; the fixed random_state makes the
# shuffle repeatable for a given record order.
data_list_freec_preds_train, data_list_freec_preds_test, \
data_list_readdepths_train, data_list_readdepths_test, \
data_list_canavar_preds_train, data_list_canavar_preds_test = train_test_split(data_list_freec_preds, data_list_readdepths, data_list_canavar_preds, test_size=0.1, random_state=35)

# np.save and model.save do not create missing parent directories, so make
# sure the output directory exists before anything is written to it.
if not os.path.isdir('./outputs'):
    os.makedirs('./outputs')

# Persist the held-out test arrays.
np.save('./outputs/data_list_freec_preds_test.npy', data_list_freec_preds_test)
np.save('./outputs/data_list_readdepths_test.npy', data_list_readdepths_test)
np.save('./outputs/data_list_canavar_preds_test.npy', data_list_canavar_preds_test)



# Train with mean absolute error; a further 20% of the training portion is
# set aside by Keras as validation data.
model.compile(loss='mean_absolute_error', optimizer='adam')
model.fit([data_list_freec_preds_train, data_list_readdepths_train], data_list_canavar_preds_train, validation_split = 0.2, epochs = 60, batch_size=512)

# NOTE(review): the file name says "bs256" but training above uses
# batch_size=512 -- confirm which is intended before relying on the name.
model.save('./outputs/deepXCNVfreec_batchnorm_bilstm128_batchnorm_dense100_dense1_bs256_padding-1_60epochs_traintestsplitted_mae.h5')

